# ------------------------------------------------------------------------------
# Plots performance of limited bandwidth models
# Author: Cassidy Shubatt <cshubatt@gmail.com>
# To run: bsub -q big -R "rusage[mem=10000]" bash 06_plot_performance.sh
# ------------------------------------------------------------------------------

# Seeding ----------------------------------------------------------------------
# Fixed seed for reproducibility.
# NOTE(review): nothing in the visible code draws random numbers; confirm
# whether the seed is still needed.
set.seed(1)

# Libraries --------------------------------------------------------------------
message("Loading libraries...")
library(here)
library(yaml)
library(tidyverse)
library(glue)
library(ggthemes)

# Working directory for this step: the performance .rds inputs are read from
# here and it is the default output directory of name_plot().
temp <- here(
  "code", "06_physician_boundedness", "01_behavioral_lasso", "temp"
)
# Shared plotting aesthetics module; this script uses `a$disc_palette` for
# series colours and `a$get_font()` below.
a <- modules::use(here("lib", "aesthetics.R"))
# Make the "Optima" family available; plot_performance() sets it on all plot
# text and on the annotation labels.
a$get_font("Optima", here("lib", "optima.ttf"))

# Helpers ----------------------------------------------------------------------
# Thin a set of values so that neighbouring axis breaks are at least
# `step_size` apart on the log10 scale.
#
# Args:
#   x: Numeric vector of candidate break positions (duplicates and order are
#     irrelevant; values are de-duplicated and sorted on the log10 scale).
#   step_size: Minimum gap, in log10 units, required between a value and the
#     previously kept value for it to be kept.
#
# Returns:
#   Ascending numeric vector of the kept values on the original scale. The
#   smallest value is always kept, and the largest value is always kept
#   exactly once, even when it lies closer than `step_size` to the previously
#   kept value. Returns numeric(0) for empty input.
log_step <- function(x, step_size = 0.25) {
  log_x <- sort(unique(log10(x)))
  n <- length(log_x)

  # Nothing to thin; avoids max() on an empty vector (warning + bogus break).
  if (n == 0) {
    return(numeric(0))
  }

  # Mark survivors instead of deleting from the vector inside the loop.
  keep <- logical(n)
  keep[1] <- TRUE
  last_kept <- log_x[1]

  for (ii in seq_len(n)[-1]) {
    if (log_x[ii] - last_kept >= step_size) {
      keep[ii] <- TRUE
      last_kept <- log_x[ii]
    }
  }

  # Always show the largest value. Flagging it (rather than appending it)
  # prevents the duplicate break produced when it had already been kept.
  keep[n] <- TRUE

  10^log_x[keep]
}

# Build a compact identifier of the form "<measure>(<target>, <score>)".
# The "010_day" part of the score name is replaced by "hat" and the
# "_010_day" suffix is dropped from the target name before formatting.
name_performance <- function(measure_name, score_name, target_name) {
  glue::glue(
    "{measure_name}({target_short}, {score_short})",
    target_short = str_remove(target_name, "_010_day"),
    score_short = str_replace(score_name, "010_day", "hat")
  )
}
# Map a target name to the human-readable series label used in the plots.
#
# Args:
#   measure_name, score_name: Unused; accepted so the function can be called
#     with the same three arguments as name_performance().
#   target_name: Character vector of target names.
#
# Returns:
#   Character vector the same length as `target_name`:
#   "Yield of Testing" when the name contains "stent", otherwise
#   "Testing Decision" when it contains "test", otherwise "MACE in Untested"
#   (this also covers NA names, which match neither pattern).
label_performance <- function(measure_name, score_name, target_name) {
  labels <- rep("MACE in Untested", length(target_name))

  # Assign the lower-priority label first so that "stent" wins when a name
  # contains both substrings.
  labels[grepl("test", target_name, fixed = TRUE)] <- "Testing Decision"
  labels[grepl("stent", target_name, fixed = TRUE)] <- "Yield of Testing"

  # Return visibly; ending on an assignment would hide the value.
  labels
}

# Output path for a measure's plot:
# <dir_name>/BEHAVIORAL_<MEASURE>[_zoomed].png
# The measure name is upper-cased; "_zoomed" is appended when `zoom` is TRUE.
name_plot <- function(measure_name, dir_name = temp, zoom = FALSE) {
  suffix <- ifelse(zoom, "_zoomed", "")
  stem <- toupper(measure_name)

  file.path(dir_name, glue("BEHAVIORAL_{stem}{suffix}.png"))
}

# Plot performance against model size for one accuracy measure.
#
# Args:
#   which_measure: Single string; rows of `tb` with this `measure_name` are
#     plotted. "auc", "r2", "rss" and "ppv" have y-axis labels defined below.
#     The legend is shown (at the bottom) only for "r2".
#   tb: Data frame with columns `measure_name`, `performance_label`, `n_coef`
#     and `performance`.
#
# Returns:
#   A ggplot object: one dashed line with points per `performance_label` on a
#   log10 x-axis, plus a vertical line and text label at the `n_coef` where
#   each series reaches its maximum performance.
plot_performance <- function(which_measure, tb) {
  gg_tb <- filter(tb, measure_name == which_measure)

  # Model size at which each series peaks. Sorting by n_coef first makes
  # which.max() (which returns the first maximum) resolve ties towards the
  # smallest model and guarantees exactly one value per series; selecting
  # every row equal to the maximum returned several values on ties.
  best_n_df <- gg_tb %>%
    arrange(n_coef) %>%
    group_by(performance_label) %>%
    summarize(best_n_coef = n_coef[which.max(performance)]) %>%
    ungroup()
  best_n <- setNames(best_n_df$best_n_coef, best_n_df$performance_label)

  # Fixed colour per series label, taken from the shared palette.
  color_key <- setNames(
    a$disc_palette[1:3],
    c("Yield of Testing", "Testing Decision", "MACE in Untested")
  )

  y_lab <- switch(which_measure,
    auc = "AUC",
    r2 = "R2",
    rss = "Residual Sum of Squares",
    ppv = "Positive Predictive Value"
  )

  legend_pos <- if (isTRUE(which_measure == "r2")) "bottom" else "none"

  gg <- ggplot(
    gg_tb,
    aes(
      x = n_coef,
      y = performance,
      # ymin = performance + delta_lo,
      # ymax = performance + delta_hi,
      color = performance_label,
      fill = performance_label
    )
  ) +
    labs(
      x = "Number of Variables",
      y = y_lab,
      # caption = glue("95% confidence intervals from 1,000 bootstrap iterations."),
      color = glue("Predictive accuracy ({y_lab}) of regularized model")
    ) +
    scale_x_continuous(
      trans = "log10",
      breaks = log_step(gg_tb$n_coef),
      minor_breaks = NULL
    ) +
    scale_color_manual(values = color_key, aesthetics = c("color", "fill")) +
    theme_bw() +
    # geom_ribbon(
    #   alpha = 0.4,
    #   color = NA
    # ) +
    geom_line(linetype = "dashed") +
    geom_point() +
    theme(
      legend.position = legend_pos,
      text = element_text(family = "Optima", size = 40)
    ) +
    # Hide the fill legend ("none" replaces the deprecated FALSE) and wrap
    # the colour legend onto two rows.
    guides(fill = "none", color = guide_legend(nrow = 2))

  # Labels sit just above the highest plotted point.
  lab_height_base <- max(gg_tb$performance, na.rm = TRUE) + 0.1
  for (grp in names(best_n)) {
    outcome <- if (grepl("Yield", grp, fixed = TRUE)) "Yield" else "Test"
    # lab <- glue("Maximum Accuracy,\n{outcome} prediction:\nk = {best_n[grp]}")
    # Stagger the non-yield label slightly lower so the labels do not share
    # the same height.
    lab_height <- lab_height_base
    if (outcome == "Test") {
      lab_height <- lab_height - 0.05
    }
    lab <- glue("Maximum Performance: {outcome}\nk = {best_n[grp]}")
    gg <- gg +
      geom_vline(
        xintercept = best_n[grp], color = color_key[grp]
      ) +
      annotate("label",
        x = best_n[grp], y = lab_height, label = lab,
        size = 14, lineheight = 0.25, fill = color_key[grp], alpha = 0.4,
        hjust = 0, family = "Optima"
      )
  }

  gg
}

# Load Data --------------------------------------------------------------------
message("Loading data...")
# Observed and bootstrap performance tables, read from the `temp` directory.
performance_obs_raw <- readRDS(file.path(temp, "performance_obs.rds"))
performance_boot_raw <- readRDS(file.path(temp, "performance_boot.rds"))

# Bootstrap deltas -------------------------------------------------------------
# Pairs each bootstrap performance with the observed performance for the same
# lambda / n_coef / score / target / measure and takes the difference.
# NOTE(review): the quantile summaries are commented out, so summarize() is
# called with no expressions and only returns the distinct grouping keys.
# `performance_boot_delta` is not referenced anywhere else in the visible
# code, so this block (and the performance_boot.rds read above) appears to be
# dead scaffolding for the disabled confidence ribbon in plot_performance().
# Confirm whether to restore the quantiles or remove it.
performance_boot_delta <- performance_boot_raw %>%
  rename(perf_boot = performance) %>%
  inner_join(performance_obs_raw, by = c("lambda", "n_coef", "score_name", "target_name", "measure_name")) %>%
  mutate(delta = perf_boot - performance) %>%
  group_by(lambda, n_coef, score_name, target_name, measure_name) %>%
  summarize(
    # delta_lo = quantile(delta, 0.025),
    # delta_hi = quantile(delta, 0.975)
  ) %>%
  ungroup()

# Assemble plotting table ------------------------------------------------------
# Keep rows for the "glm" model, scored with "stent_or_cabg_010_day", against
# the testing and stent/CABG targets. Filtering before labelling means the
# naming helpers only run on rows that are actually plotted, and calling them
# column-wise (both are vectorised) replaces the row-by-row pmap_chr() calls.
performance_tb <- performance_obs_raw %>%
  filter(
    target_name %in% c("test_010_day", "stent_or_cabg_010_day"),
    score_name == "stent_or_cabg_010_day",
    # model_name == "lasso"
    model_name == "glm"
  ) %>%
  mutate(
    # as.character() drops the glue class so the column is a plain character
    # vector, as pmap_chr() produced.
    performance_name = as.character(
      name_performance(measure_name, score_name, target_name)
    ),
    performance_label = as.character(
      label_performance(measure_name, score_name, target_name)
    )
  )

# Render and save --------------------------------------------------------------
# One plot per measure; each row pairs an output path with its ggplot object.
design <- tibble(which_measure = c("auc", "r2"))

plots <- transmute(
  design,
  filename = name_plot(which_measure),
  plot = map(which_measure, plot_performance, tb = performance_tb)
)

# Column names match ggsave()'s `filename` and `plot` arguments, so pwalk()
# saves each row as a 10 x 7 inch image.
pwalk(plots, ggsave, width = 10, height = 7, units = "in")

message("Done.")
